Description of variables
# import library
import pandas as pd
import seaborn as sns
# import data: read the employee-attrition CSV from its GitHub URL
DATA_URL = 'https://github.com/ybifoundation/Dataset/raw/main/EmployeeAttrition.csv'
attrition = pd.read_csv(DATA_URL)
# view data: preview the first rows to sanity-check the load
attrition.head()
| Age | Attrition | BusinessTravel | DailyRate | Department | DistanceFromHome | Education | EducationField | EmployeeCount | EmployeeNumber | ... | RelationshipSatisfaction | StandardHours | StockOptionLevel | TotalWorkingYears | TrainingTimesLastYear | WorkLifeBalance | YearsAtCompany | YearsInCurrentRole | YearsSinceLastPromotion | YearsWithCurrManager | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 41 | Yes | Travel_Rarely | 1102 | Sales | 1 | 2 | Life Sciences | 1 | 1 | ... | 1 | 80 | 0 | 8 | 0 | 1 | 6 | 4 | 0 | 5 |
| 1 | 49 | No | Travel_Frequently | 279 | Research & Development | 8 | 1 | Life Sciences | 1 | 2 | ... | 4 | 80 | 1 | 10 | 3 | 3 | 10 | 7 | 1 | 7 |
| 2 | 37 | Yes | Travel_Rarely | 1373 | Research & Development | 2 | 2 | Other | 1 | 4 | ... | 2 | 80 | 0 | 7 | 3 | 3 | 0 | 0 | 0 | 0 |
| 3 | 33 | No | Travel_Frequently | 1392 | Research & Development | 3 | 4 | Life Sciences | 1 | 5 | ... | 3 | 80 | 0 | 8 | 3 | 3 | 8 | 7 | 3 | 0 |
| 4 | 27 | No | Travel_Rarely | 591 | Research & Development | 2 | 1 | Medical | 1 | 7 | ... | 4 | 80 | 1 | 6 | 3 | 3 | 2 | 2 | 2 | 2 |
5 rows × 35 columns
# info of data: prints the index range, per-column non-null counts,
# column dtypes and the memory footprint of the DataFrame
attrition.info()
<class 'pandas.core.frame.DataFrame'> RangeIndex: 1470 entries, 0 to 1469 Data columns (total 35 columns): # Column Non-Null Count Dtype --- ------ -------------- ----- 0 Age 1470 non-null int64 1 Attrition 1470 non-null object 2 BusinessTravel 1470 non-null object 3 DailyRate 1470 non-null int64 4 Department 1470 non-null object 5 DistanceFromHome 1470 non-null int64 6 Education 1470 non-null int64 7 EducationField 1470 non-null object 8 EmployeeCount 1470 non-null int64 9 EmployeeNumber 1470 non-null int64 10 EnvironmentSatisfaction 1470 non-null int64 11 Gender 1470 non-null object 12 HourlyRate 1470 non-null int64 13 JobInvolvement 1470 non-null int64 14 JobLevel 1470 non-null int64 15 JobRole 1470 non-null object 16 JobSatisfaction 1470 non-null int64 17 MaritalStatus 1470 non-null object 18 MonthlyIncome 1470 non-null int64 19 MonthlyRate 1470 non-null int64 20 NumCompaniesWorked 1470 non-null int64 21 Over18 1470 non-null object 22 OverTime 1470 non-null object 23 PercentSalaryHike 1470 non-null int64 24 PerformanceRating 1470 non-null int64 25 RelationshipSatisfaction 1470 non-null int64 26 StandardHours 1470 non-null int64 27 StockOptionLevel 1470 non-null int64 28 TotalWorkingYears 1470 non-null int64 29 TrainingTimesLastYear 1470 non-null int64 30 WorkLifeBalance 1470 non-null int64 31 YearsAtCompany 1470 non-null int64 32 YearsInCurrentRole 1470 non-null int64 33 YearsSinceLastPromotion 1470 non-null int64 34 YearsWithCurrManager 1470 non-null int64 dtypes: int64(26), object(9) memory usage: 350.3+ KB
# summary statistics: count, mean, std, min, quartiles and max
# for each numeric column (object-typed columns are left out)
attrition.describe()
| Age | DailyRate | DistanceFromHome | Education | EmployeeCount | EmployeeNumber | EnvironmentSatisfaction | HourlyRate | JobInvolvement | JobLevel | ... | RelationshipSatisfaction | StandardHours | StockOptionLevel | TotalWorkingYears | TrainingTimesLastYear | WorkLifeBalance | YearsAtCompany | YearsInCurrentRole | YearsSinceLastPromotion | YearsWithCurrManager | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| count | 1470.000000 | 1470.000000 | 1470.000000 | 1470.000000 | 1470.0 | 1470.000000 | 1470.000000 | 1470.000000 | 1470.000000 | 1470.000000 | ... | 1470.000000 | 1470.0 | 1470.000000 | 1470.000000 | 1470.000000 | 1470.000000 | 1470.000000 | 1470.000000 | 1470.000000 | 1470.000000 |
| mean | 36.923810 | 802.485714 | 9.192517 | 2.912925 | 1.0 | 1024.865306 | 2.721769 | 65.891156 | 2.729932 | 2.063946 | ... | 2.712245 | 80.0 | 0.793878 | 11.279592 | 2.799320 | 2.761224 | 7.008163 | 4.229252 | 2.187755 | 4.123129 |
| std | 9.135373 | 403.509100 | 8.106864 | 1.024165 | 0.0 | 602.024335 | 1.093082 | 20.329428 | 0.711561 | 1.106940 | ... | 1.081209 | 0.0 | 0.852077 | 7.780782 | 1.289271 | 0.706476 | 6.126525 | 3.623137 | 3.222430 | 3.568136 |
| min | 18.000000 | 102.000000 | 1.000000 | 1.000000 | 1.0 | 1.000000 | 1.000000 | 30.000000 | 1.000000 | 1.000000 | ... | 1.000000 | 80.0 | 0.000000 | 0.000000 | 0.000000 | 1.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 |
| 25% | 30.000000 | 465.000000 | 2.000000 | 2.000000 | 1.0 | 491.250000 | 2.000000 | 48.000000 | 2.000000 | 1.000000 | ... | 2.000000 | 80.0 | 0.000000 | 6.000000 | 2.000000 | 2.000000 | 3.000000 | 2.000000 | 0.000000 | 2.000000 |
| 50% | 36.000000 | 802.000000 | 7.000000 | 3.000000 | 1.0 | 1020.500000 | 3.000000 | 66.000000 | 3.000000 | 2.000000 | ... | 3.000000 | 80.0 | 1.000000 | 10.000000 | 3.000000 | 3.000000 | 5.000000 | 3.000000 | 1.000000 | 3.000000 |
| 75% | 43.000000 | 1157.000000 | 14.000000 | 4.000000 | 1.0 | 1555.750000 | 4.000000 | 83.750000 | 3.000000 | 3.000000 | ... | 4.000000 | 80.0 | 1.000000 | 15.000000 | 3.000000 | 3.000000 | 9.000000 | 7.000000 | 3.000000 | 7.000000 |
| max | 60.000000 | 1499.000000 | 29.000000 | 5.000000 | 1.0 | 2068.000000 | 4.000000 | 100.000000 | 4.000000 | 5.000000 | ... | 4.000000 | 80.0 | 3.000000 | 40.000000 | 6.000000 | 4.000000 | 40.000000 | 18.000000 | 15.000000 | 17.000000 |
8 rows × 26 columns
# check for missing values: count the null entries in every column
missing_per_column = attrition.isnull().sum()
missing_per_column
Age 0 Attrition 0 BusinessTravel 0 DailyRate 0 Department 0 DistanceFromHome 0 Education 0 EducationField 0 EmployeeCount 0 EmployeeNumber 0 EnvironmentSatisfaction 0 Gender 0 HourlyRate 0 JobInvolvement 0 JobLevel 0 JobRole 0 JobSatisfaction 0 MaritalStatus 0 MonthlyIncome 0 MonthlyRate 0 NumCompaniesWorked 0 Over18 0 OverTime 0 PercentSalaryHike 0 PerformanceRating 0 RelationshipSatisfaction 0 StandardHours 0 StockOptionLevel 0 TotalWorkingYears 0 TrainingTimesLastYear 0 WorkLifeBalance 0 YearsAtCompany 0 YearsInCurrentRole 0 YearsSinceLastPromotion 0 YearsWithCurrManager 0 dtype: int64
# check for categories: list each column's dtype; columns reported as
# `object` hold text categories rather than numbers
attrition.dtypes
Age int64 Attrition object BusinessTravel object DailyRate int64 Department object DistanceFromHome int64 Education int64 EducationField object EmployeeCount int64 EmployeeNumber int64 EnvironmentSatisfaction int64 Gender object HourlyRate int64 JobInvolvement int64 JobLevel int64 JobRole object JobSatisfaction int64 MaritalStatus object MonthlyIncome int64 MonthlyRate int64 NumCompaniesWorked int64 Over18 object OverTime object PercentSalaryHike int64 PerformanceRating int64 RelationshipSatisfaction int64 StandardHours int64 StockOptionLevel int64 TotalWorkingYears int64 TrainingTimesLastYear int64 WorkLifeBalance int64 YearsAtCompany int64 YearsInCurrentRole int64 YearsSinceLastPromotion int64 YearsWithCurrManager int64 dtype: object
# visualize pairplot: draw a grid of pairwise plots over the columns of
# `attrition`; the call returns a seaborn PairGrid
sns.pairplot(attrition)
<seaborn.axisgrid.PairGrid at 0x715d3b8>
# column names: show the label of every column in the DataFrame
attrition.columns
Index(['Age', 'Attrition', 'BusinessTravel', 'DailyRate', 'Department',
'DistanceFromHome', 'Education', 'EducationField', 'EmployeeCount',
'EmployeeNumber', 'EnvironmentSatisfaction', 'Gender', 'HourlyRate',
'JobInvolvement', 'JobLevel', 'JobRole', 'JobSatisfaction',
'MaritalStatus', 'MonthlyIncome', 'MonthlyRate', 'NumCompaniesWorked',
'Over18', 'OverTime', 'PercentSalaryHike', 'PerformanceRating',
'RelationshipSatisfaction', 'StandardHours', 'StockOptionLevel',
'TotalWorkingYears', 'TrainingTimesLastYear', 'WorkLifeBalance',
'YearsAtCompany', 'YearsInCurrentRole', 'YearsSinceLastPromotion',
'YearsWithCurrManager'],
dtype='object')
# column names of numerical columns
# select_dtypes(include='number') keeps only the numeric columns, so the
# object-typed ones (Attrition, BusinessTravel, Department, ...) are left out.
attrition.select_dtypes(include='number').columns
Index(['Age', 'Attrition', 'BusinessTravel', 'DailyRate', 'Department',
'DistanceFromHome', 'Education', 'EducationField', 'EmployeeCount',
'EmployeeNumber', 'EnvironmentSatisfaction', 'Gender', 'HourlyRate',
'JobInvolvement', 'JobLevel', 'JobRole', 'JobSatisfaction',
'MaritalStatus', 'MonthlyIncome', 'MonthlyRate', 'NumCompaniesWorked',
'Over18', 'OverTime', 'PercentSalaryHike', 'PerformanceRating',
'RelationshipSatisfaction', 'StandardHours', 'StockOptionLevel',
'TotalWorkingYears', 'TrainingTimesLastYear', 'WorkLifeBalance',
'YearsAtCompany', 'YearsInCurrentRole', 'YearsSinceLastPromotion',
'YearsWithCurrManager'],
dtype='object')
# define y: the target label to predict
y = attrition['Attrition']
# define X: the feature matrix
# X has to be a DataFrame sliced from `attrition`, not a bare nested list of
# column names. It must also leave out the target column itself (otherwise the
# label leaks into the features) and the object-typed columns, which the
# classifier cannot fit without encoding. Keeping only the numeric columns
# gives the 26 features reported by the shape check after the split.
X = attrition.select_dtypes(include='number')
# split data: 70% of the rows go to training, the rest to testing;
# the fixed random_state keeps the split identical between runs
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    train_size=0.7,
    random_state=2529,
)
# verify shape: rows/columns of each of the four pieces
X_train.shape, X_test.shape, y_train.shape, y_test.shape
((1029, 26), (441, 26), (1029,), (441,))
# select model
from sklearn.ensemble import RandomForestClassifier
# Seed the forest so repeated runs build the same trees and report the same
# metrics; the seed matches the one already used for the train/test split.
model = RandomForestClassifier(random_state=2529)
# train model: fit the forest on the training features and labels
model.fit(X_train, y_train)
RandomForestClassifier()
# predict with model: label every row of the held-out test set
y_pred = model.predict(X_test)
# model accuracy: fraction of test rows whose predicted label is correct
from sklearn.metrics import accuracy_score
accuracy_score(y_true=y_test, y_pred=y_pred)
0.8571428571428571
# model confusion matrix: rows are the true labels, columns the predictions
from sklearn.metrics import confusion_matrix
confusion_matrix(y_true=y_test, y_pred=y_pred)
array([[369, 5],
[ 58, 9]])
# model classification report: per-class precision, recall, f1 and support
from sklearn.metrics import classification_report
report = classification_report(y_test, y_pred)
print(report)
precision recall f1-score support
No 0.86 0.99 0.92 374
Yes 0.64 0.13 0.22 67
accuracy 0.86 441
macro avg 0.75 0.56 0.57 441
weighted avg 0.83 0.86 0.82 441
# future prediction: draw one random employee record from the dataset to
# stand in for unseen data (`X_sample` was never defined; the record has to
# come from `attrition.sample()`)
X_new = attrition.sample()
X_new
| Age | Attrition | BusinessTravel | DailyRate | Department | DistanceFromHome | Education | EducationField | EmployeeCount | EmployeeNumber | ... | RelationshipSatisfaction | StandardHours | StockOptionLevel | TotalWorkingYears | TrainingTimesLastYear | WorkLifeBalance | YearsAtCompany | YearsInCurrentRole | YearsSinceLastPromotion | YearsWithCurrManager | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 1052 | 30 | No | Non-Travel | 990 | Research & Development | 7 | 3 | Technical Degree | 1 | 1482 | ... | 2 | 80 | 2 | 1 | 2 | 2 | 1 | 0 | 0 | 0 |
1 rows × 35 columns
# define X_new: keep only the numeric columns of the sampled record, dropping
# Attrition and the other object-typed columns so that the row passed to the
# model contains numbers only
X_new = X_new.select_dtypes(include='number')
X_new
| Age | DailyRate | DistanceFromHome | Education | EmployeeCount | EmployeeNumber | EnvironmentSatisfaction | HourlyRate | JobInvolvement | JobLevel | ... | RelationshipSatisfaction | StandardHours | StockOptionLevel | TotalWorkingYears | TrainingTimesLastYear | WorkLifeBalance | YearsAtCompany | YearsInCurrentRole | YearsSinceLastPromotion | YearsWithCurrManager | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 1052 | 30 | 990 | 7 | 3 | 1 | 1482 | 3 | 64 | 3 | 1 | ... | 2 | 80 | 2 | 1 | 2 | 2 | 1 | 0 | 0 | 0 |
1 rows × 26 columns
# predict for X_new: returns an array holding the predicted Attrition label
# for the sampled record
model.predict(X_new)
array(['No'], dtype=object)